#import train.train as train
from trainMFAR import train
from data import GameDataMeta
import os.path
# import bogota.data
import argparse
import time
import json
import numpy as np
# import mobdata


def kfold(fold_function, start_fold=0, end_fold=10):
    """Run ``fold_function`` once for every fold index in a half-open range.

    Args:
        fold_function: Callable invoked as ``fold_function(i)`` with the
            zero-based fold index.
        start_fold: First fold index to run (inclusive).
        end_fold: Fold index at which to stop (exclusive).

    Progress and wall-clock duration of each fold are printed to stdout,
    using one-based fold numbers in the messages.
    """
    separator = '*' * 100
    for fold_index in range(start_fold, end_fold):
        fold_number = fold_index + 1  # messages are one-based
        print('STARTING FOLD: %d' % fold_number)
        started_at = time.time()
        fold_function(fold_index)
        elapsed = time.time() - started_at
        print('Fold %d complete in %f seconds' % (fold_number, elapsed))
        print(separator)


def parse_args():
    """Parse the command-line options of the cross-validation script.

    Returns:
        The ``argparse.Namespace`` with ``start_fold`` (int, default 0),
        ``end_fold`` (int, default 10), ``path`` (str, default ``''``) and
        ``json`` (path to a JSON options file, default ``None``).
    """
    cli = argparse.ArgumentParser(description="K-fold cross validation")

    # Integer fold bounds share the same shape, so register them together.
    for flag, fallback in (('--start_fold', 0), ('--end_fold', 10)):
        cli.add_argument(flag, default=fallback, type=int)

    cli.add_argument('--path', default='')
    cli.add_argument(
        '--json',
        default=None,
        help="Path of json file describing the options of the experiment",
    )
    parsed = cli.parse_args()
    return parsed



def build_fold_function(options, new_experiment=False, resume=False):
    """Prepare output locations and return the per-fold training callable.

    Args:
        options: Dict of experiment options. Keys read here are
            ``save_path`` (output prefix; falls back to ``path`` and then
            ``'./'`` when absent), ``name`` (default ``'test'``),
            ``dataset`` (default ``'all9'``), ``seed`` (default 12) and
            ``model_seed`` (default -3). The dict is also handed to
            ``train`` and gets its ``'fold'`` key set on every fold.
        new_experiment: Currently unused; kept for interface compatibility.
        resume: Currently unused; kept for interface compatibility.

    Returns:
        A function ``fold_function(k)`` that trains on fold ``k``, appends
        the resulting scores to the best-loss log and dumps the best
        parameters to ``<save_path>test/<name>_<k>_par.json``.

    Side effects:
        Creates the ``test/`` and ``test/best_loss/`` directories under the
        save path and, if it does not exist yet, writes the summary CSV with
        its two header lines.
    """
    # A missing 'save_path' used to crash on ``None + str``; fall back to the
    # 'path' key (the one main() fills from --path) and then to the cwd.
    save_path = options.get('save_path')
    if save_path is None:
        save_path = options.get('path') or './'
    name = options.get('name', 'test')

    # Plain concatenation is kept on purpose so existing configs resolve to
    # exactly the same file names as before.
    output_file = save_path + "test/" + name + '.csv'
    best_loss_filename = save_path + "test/best_loss/" + name + '.csv'
    par_file = save_path + "test/" + name + '_%d_par.json'

    # Create the directories the files above actually live in, rather than a
    # fixed "./test/best_loss" that only matches when save_path is "./".
    for target in (output_file, best_loss_filename):
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

    dataset_name = options.get('dataset', 'all9')
    seed = options.get('seed', 12)
    if not os.path.isfile(output_file):
        with open(output_file, 'w') as f:
            f.write('Data: %s, seed: %d\n' % (dataset_name, seed))
            f.write(','.join(['fold', 'seed', 'train', 'valid', 'test']) + '\n')

    def fold_function(k):  # k is the fold index
        # NOTE(review): the data file is hard-coded; the 'dataset' option is
        # only used for the header line above -- confirm this is intended.
        data = GameDataMeta('unique127.csv', 1.)
        train_data, test_data = data.train_test(k, seed=seed)
        options['fold'] = k
        llk, best_par = train(options, [train_data.datalist(), test_data.datalist()], True)
        print("LLK: ", llk)

        # json cannot serialise array objects, so convert every value with
        # its .tolist() method before dumping.
        serialisable_par = {key: value.tolist() for key, value in best_par.items()}
        log_fold(best_loss_filename, llk, k, options.get('model_seed', -3))

        with open(par_file % k, 'w') as f:
            json.dump(serialisable_par, f)

    return fold_function


def log_fold(log_file_name, llk, fold, model_seed, llk_start=None):
    """Append one comma-separated result row for a fold to the "_out" log.

    Args:
        log_file_name: Base log path. A trailing ``.csv`` is turned into
            ``_out.csv``; a name without that suffix is used unchanged.
        llk: Iterable of score values written after the fold and seed.
        fold: Fold index, written as the first column.
        model_seed: Model seed, written as the second column.
        llk_start: Optional iterable of extra values appended after ``llk``.

    The file is opened in append mode, so repeated calls add rows.
    """
    # Rewrite only the file suffix: str.replace would also alter any ".csv"
    # that happens to appear in a directory name or earlier in the file name.
    suffix = ".csv"
    if log_file_name.endswith(suffix):
        log_file_name = log_file_name[:-len(suffix)] + "_out.csv"

    row = [fold, model_seed] + list(llk)
    if llk_start is not None:
        row.extend(llk_start)

    with open(log_file_name, 'a') as f:
        f.write(','.join(str(item) for item in row) + '\n')

def str_lst(x):
    """Return a new list holding ``str()`` of every element of ``x``."""
    return list(map(str, x))

# Fallback experiment options, used by main() when no --json file is given.
# Within this file only 'name', 'save_path' and 'model_seed' are read; the
# remaining entries are handed through to train() unchanged.
DEFAULT_OPTIONS = dict(
    name='test',
    save_path='./',
    manager_units=[50, 50],
    expert1_units=[1, 1],
    expert2_units=[1, 1],
    expert3_units=[1],
    expert4_units=[1, 1],
    expert5_units=[1],
    activ='relu',
    pooling=True,
    batch_size=None,
    ar_layers=1,
    dropout=False,
    l1=0.01,
    l2=0.0,
    pooling_activ='max',
    opt='adam',
    max_itr=1,
    model_seed=123,
    objective='nll',
)

def main():
    """Entry point: load the options, build the fold function, run the folds.

    Options come from the file given with ``--json`` when present (with
    ``--path`` stored under the ``'path'`` key if non-empty); otherwise a
    copy of ``DEFAULT_OPTIONS`` is used. Folds ``--start_fold`` (inclusive)
    to ``--end_fold`` (exclusive) are then executed.
    """
    args = parse_args()
    if args.json is not None:
        # Use a context manager so the file handle is closed deterministically.
        with open(args.json) as f:
            options = json.load(f)
        print("OPTIONS FROM JSON LOADED SUCCESSFULLY.")
        if args.path != '':
            options['path'] = args.path
    else:
        # Work on a copy: the fold function writes the current fold index
        # into this dict, which must not leak into the module-level defaults.
        options = dict(DEFAULT_OPTIONS)

    fold_function = build_fold_function(options, args.start_fold == 0, False)
    kfold(fold_function, args.start_fold, args.end_fold)


if __name__ == '__main__':
    main()
